/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#include "fmha.h"

extern at::Tensor &mha_fill(at::Tensor &self, const at::Tensor &start_index);
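
// Fill the kernel parameter struct shared by the forward and backward
// launchers.
//   b, s, h, d   : batch size, padded sequence length, num heads, head size
//   qkv_packed_d : packed QKV activations (FP16)
//   cu_seqlens_d : cumulative sequence lengths, int32, b+1 entries
//   o_packed_d   : output (forward) or dout (backward)
//   s_d          : softmax / dropout-mask buffer
//   p_dropout    : probability of dropping an element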
void set_params(Fused_multihead_attention_fprop_params &params,
                // sizes
                const size_t b, const size_t s, const size_t h, const size_t d,
                // device pointers
                void *qkv_packed_d, void *cu_seqlens_d, void *o_packed_d, void *s_d, float p_dropout) {
  Data_type acc_type = DATA_TYPE_FP32;
  Data_type data_type = DATA_TYPE_FP16;

  // Reset the parameters
  memset(&params, 0, sizeof(params));

  // Set the pointers and strides.
  params.qkv_ptr = qkv_packed_d;
  params.qkv_stride_in_bytes = get_size_in_bytes(h * 3 * d, data_type);
  params.o_ptr = o_packed_d;
  params.o_stride_in_bytes = get_size_in_bytes(h * d, data_type);

  params.cu_seqlens = static_cast<int *>(cu_seqlens_d);

  // S = softmax(P)
  params.s_ptr = s_d;
  params.s_stride_in_bytes = get_size_in_bytes(b * h * s, data_type);

  // Set the dimensions.
  params.b = b;
  params.h = h;
  params.s = s;
  params.d = d;

  // Set the different scale values.
  const float scale_bmm1 = 1.f / sqrtf(d);
  constexpr float scale_softmax = 1.f;
  constexpr float scale_bmm2 = 1.f;

  set_alpha(params.scale_bmm1, scale_bmm1, data_type);
  set_alpha(params.scale_softmax, scale_softmax, acc_type);
  set_alpha(params.scale_bmm2, scale_bmm2, data_type);

  // Store the probability of *keeping* an element to simplify the kernels.
  TORCH_CHECK(p_dropout < 1.f);
  params.p_dropout = 1.f - p_dropout;
  // Inverted dropout: kept activations are rescaled by 1 / keep-probability.
  params.rp_dropout = 1.f / params.p_dropout;
  set_alpha(params.scale_dropout, params.rp_dropout, data_type);
}

std::vector<at::Tensor> mha_fwd(
    const at::Tensor &qkv,         // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b-1} s_i
    const at::Tensor &cu_seqlens,  // b+1
    const float p_dropout, const int max_seq_len, const bool is_training, const bool is_nl, const bool zero_tensors,
    c10::optional<at::Generator> gen_) {
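  // Select a kernel specialized for the padded sequence length, run a
  // configure pass to size the dropout RNG, then launch. Returns the
  // attention output `ctx` and the softmax/dropout buffer `s` that the
  // backward pass consumes.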
  using namespace torch::indexing;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) ||
                  (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0),
              "FMHA only supports sm80, sm90, sm100 and sm120 GPUs");
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_training, is_nl);

  int seq_len = 512;
  auto launch = &run_fmha_fp16_512_64_sm80;
  if (max_seq_len <= 128) {
    seq_len = 128;
    launch = &run_fmha_fp16_128_64_sm80;
  } else if (max_seq_len <= 256) {
    seq_len = 256;
    launch = &run_fmha_fp16_256_64_sm80;
  } else if (max_seq_len <= 384) {
    seq_len = 384;
    launch = &run_fmha_fp16_384_64_sm80;
  } else if (max_seq_len <= 512) {
    seq_len = 512;
    launch = &run_fmha_fp16_512_64_sm80;
  } else {
    TORCH_CHECK(false, "max_seq_len must be <= 512");
  }

  TORCH_CHECK(qkv.dtype() == torch::kFloat16);
  TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

  TORCH_CHECK(qkv.is_cuda());
  TORCH_CHECK(cu_seqlens.is_cuda());

  TORCH_CHECK(qkv.is_contiguous());
  TORCH_CHECK(cu_seqlens.is_contiguous());

  TORCH_CHECK(cu_seqlens.dim() == 1);
  TORCH_CHECK(qkv.dim() == 4);

  const auto sizes = qkv.sizes();

  TORCH_CHECK(sizes[THREE_DIM] == 3);

  const int batch_size = cu_seqlens.numel() - 1;
  const int total = sizes[TOTAL_DIM];
  const int num_heads = sizes[H_DIM];
  const int head_size = sizes[D_DIM];
  TORCH_CHECK(batch_size > 0);
  TORCH_CHECK(head_size == 64);
  auto opts = qkv.options();

  // Packed attention output, same token layout as qkv.
  auto ctx = torch::empty({total, num_heads, head_size}, opts);

  // Softmax output (and dropout mask), saved for the backward pass.
  auto s = torch::empty({batch_size, num_heads, seq_len, seq_len}, opts);

  if (zero_tensors) {
    // cu_seqlens[-1] is the number of valid tokens, so only the padding
    // tail of ctx needs to be zeroed (see mha_fill).
    mha_fill(ctx, cu_seqlens.index({Slice(-1, None)}));
  }

  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());

  set_params(launch_params.params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(),
             ctx.data_ptr(), s.data_ptr(), p_dropout);

  // Configure-only pass: computes launch geometry and elts_per_thread
  // without running the kernel.
  launch(launch_params, /*configure=*/true);
  // Number of random values generated per thread, used to offset the Philox
  // counter in the CUDA RNG state.
  int64_t counter_offset = launch_params.elts_per_thread;
  at::PhiloxCudaState rng_engine_inputs;

  if (is_training) {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
  }

  // Second pass: actually launch the forward kernel.
  launch(launch_params, /*configure=*/false);

  return {ctx, s};
}

std::vector<at::Tensor> mha_bwd(
    const at::Tensor &dout,        // total x num_heads x head_size
    const at::Tensor &qkv,         // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b-1} s_i
    at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP
    const at::Tensor &cu_seqlens,  // b+1
    const float p_dropout,         // probability to drop
    const int max_seq_len,         // max sequence length to choose the kernel
    const bool zero_tensors) {
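  // Reuses the forward parameter struct: o_ptr points at dout and s_ptr at
  // the stored softmax, which the kernel overwrites with dP. Returns the
  // packed gradient dqkv and the overwritten softmax buffer.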
  using namespace torch::indexing;
  auto dprops = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0) ||
                  (dprops->major == 10 && dprops->minor == 0) || (dprops->major == 12 && dprops->minor == 0),
              "FMHA only supports sm80, sm90, sm100 and sm120 GPUs");
  int seq_len = 512;
  auto launch = &run_fmha_dgrad_fp16_512_64_sm80;
  if (max_seq_len <= 128) {
    seq_len = 128;
    launch = &run_fmha_dgrad_fp16_128_64_sm80;
  } else if (max_seq_len <= 256) {
    seq_len = 256;
    launch = &run_fmha_dgrad_fp16_256_64_sm80;
  } else if (max_seq_len <= 384) {
    seq_len = 384;
    launch = &run_fmha_dgrad_fp16_384_64_sm80;
  } else if (max_seq_len <= 512) {
    seq_len = 512;
    launch = &run_fmha_dgrad_fp16_512_64_sm80;
  } else {
    TORCH_CHECK(false, "max_seq_len must be <= 512");
  }

  auto stream = at::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK(qkv.dtype() == torch::kFloat16);
  TORCH_CHECK(dout.dtype() == torch::kFloat16);
  TORCH_CHECK(softmax.dtype() == torch::kFloat16);
  TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);

  TORCH_CHECK(qkv.is_cuda());
  TORCH_CHECK(cu_seqlens.is_cuda());

  TORCH_CHECK(qkv.is_contiguous());
  TORCH_CHECK(cu_seqlens.is_contiguous());

  TORCH_CHECK(cu_seqlens.dim() == 1);
  TORCH_CHECK(qkv.dim() == 4);

  const auto sizes = qkv.sizes();

  TORCH_CHECK(sizes[THREE_DIM] == 3);

  const int batch_size = cu_seqlens.numel() - 1;
  const int num_heads = sizes[H_DIM];
  const int head_size = sizes[D_DIM];
  TORCH_CHECK(batch_size > 0);
  TORCH_CHECK(head_size == 64);

  auto dqkv = torch::empty_like(qkv);

  if (zero_tensors) {
    // As in the forward pass, zero only the padding tail of dqkv.
    mha_fill(dqkv, cu_seqlens.index({Slice(-1, None)}));
  }

  Fused_multihead_attention_fprop_params params;

  set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(),
             dout.data_ptr(),     // we set o_ptr to dout
             softmax.data_ptr(),  // softmax gets overwritten by dP!
             p_dropout);

  // The backward kernel repurposes the scale slots: BMM1 runs unscaled and
  // the 1/sqrt(d) factor is applied at the softmax stage instead.
  Data_type acc_type = DATA_TYPE_FP32;
  set_alpha(params.scale_bmm1, 1.f, acc_type);
  set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
  set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
  params.dqkv_ptr = dqkv.data_ptr();

  launch(params, stream);
  return {dqkv, softmax};
}

std::vector<at::Tensor> mha_bwd_nl(
    const at::Tensor &dout,        // total x num_heads x head_size
    const at::Tensor &qkv,         // total x 3 x num_heads x head_size, total := \sum_{i=0}^{b-1} s_i
    at::Tensor &softmax,           // b x h x s x s softmax and dmask - will be overwritten with dP
    const at::Tensor &cu_seqlens,  // b+1
    const float p_dropout,         // probability to drop
    const int max_seq_len,         // max sequence length to choose the kernel
    const bool zero_tensors) {
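  // "No-loop" backward for small batches: dK/dV are computed in num_chunks
  // independent pieces to expose more parallelism, then reduced into dqkv by
  // fmha_run_noloop_reduce (a split-K style reduction).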
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  TORCH_CHECK(qkv.is_cuda());
  TORCH_CHECK(cu_seqlens.is_cuda());

  TORCH_CHECK(qkv.is_contiguous());
  TORCH_CHECK(cu_seqlens.is_contiguous());

  TORCH_CHECK(cu_seqlens.dim() == 1);

  TORCH_CHECK(qkv.dim() == 4);

  const auto sizes = qkv.sizes();

  TORCH_CHECK(sizes[THREE_DIM] == 3);

  const int batch_size = cu_seqlens.numel() - 1;

  const int total = sizes[TOTAL_DIM];
  const int num_heads = sizes[H_DIM];
  const int head_size = sizes[D_DIM];
  TORCH_CHECK(batch_size > 0);
  TORCH_CHECK(head_size == 64);

  int seq_len = 512;
  auto launch = &run_fmha_dgrad_fp16_512_64_sm80_nl;

  auto opts = qkv.options();

  auto dqkv = torch::empty_like(qkv);

  if (zero_tensors) {
    dqkv.zero_();
  }

  // Heuristic: small batches underutilize the GPU, so split the dK/dV
  // computation into more chunks to expose extra parallelism.
  int num_chunks = 2;
  if (batch_size == 1) {
    num_chunks = 4;
  } else if (batch_size == 2) {
    num_chunks = 3;
  }
  // Scratch buffer for the per-chunk partial dK/dV, reduced below.
  auto dkv = torch::empty({total, num_chunks, 2, num_heads, head_size}, opts);

  Fused_multihead_attention_fprop_params params;

  set_params(params, batch_size, seq_len, num_heads, head_size, qkv.data_ptr(), cu_seqlens.data_ptr(),
             dout.data_ptr(),     // o_ptr = dout
             softmax.data_ptr(),  // softmax gets overwritten by dP!
             p_dropout);

  params.dkv_ptr = dkv.data_ptr();

  Data_type acc_type = DATA_TYPE_FP32;
  set_alpha(params.scale_bmm1, 1.f, acc_type);
  set_alpha(params.scale_softmax, 1.f / sqrtf(head_size), acc_type);
  set_alpha(params.scale_bmm2, 1.f, DATA_TYPE_FP16);
  params.dqkv_ptr = dqkv.data_ptr();

  launch(params, num_chunks, stream);

  // SPLIT-K reduction of num_chunks dK, dV parts

  // The equivalent of the following PyTorch code:
  // using namespace torch::indexing;
  // at::Tensor view_out = dqkv.index({Slice(), Slice(1, None, None)});
  // torch::sum_out(view_out, dkv, 1);

  const int hidden_size = num_heads * head_size;
  fmha_run_noloop_reduce(dqkv.data_ptr(), dkv.data_ptr(), cu_seqlens.data_ptr<int>(), hidden_size, batch_size, total,
                         num_chunks, stream);

  return {dqkv, softmax, dkv};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.doc() = "Fused Multi-head Self-attention for BERT";
  m.def("fwd", &mha_fwd, "Forward pass", py::call_guard<py::gil_scoped_release>());
  m.def("bwd", &mha_bwd, "Backward pass", py::call_guard<py::gil_scoped_release>());
  m.def("bwd_nl", &mha_bwd_nl, "Backward pass (small-batch)", py::call_guard<py::gil_scoped_release>());
}
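
// Example usage from Python (illustrative sketch; the actual module name is
// whatever TORCH_EXTENSION_NAME is set to at build time):
//
//   import fmhalib  # hypothetical build name
//   # qkv: (total, 3, heads, 64) fp16 CUDA tensor; cu_seqlens: (b+1,) int32
//   ctx, S = fmhalib.fwd(qkv, cu_seqlens, p_dropout, max_seq_len,
//                        True,   # is_training
//                        False,  # is_nl
//                        True,   # zero_tensors
//                        None)   # generator
//   dqkv, dP = fmhalib.bwd(dout, qkv, S, cu_seqlens, p_dropout, max_seq_len,
//                          True)  # zero_tensors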
